# -*- coding: utf-8 -*-
"""
Created on Thu Aug  1 12:48:16 2024

@author: xinyuan.wei
"""
import os
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Stage to analyze; must be one of the keys of `stages` below
# ('stage_1', 'stage_2', 'stage_3' or 'stage_4').
stage = 'stage_4'

# Stand-age window for this run: only plots with
# stage_min < STDAGE <= stage_max are kept. These bounds are set by hand and
# are NOT derived from `stage`, so keep the two in sync when switching stages.
stage_min, stage_max = 100, 140
stage_label = '{}-{}'.format(stage_min, stage_max)

# How many of the most frequent dominant species to show in the PC1/PC2 plot.
top_species = 20

# Output location: the filtered data with PC scores is written to
# <parent directory>/<results_file>/<savefile>.
results_file = 'Composition_Results'
savefile = 'PCA_Scores_{}_{}.csv'.format(stage_label, stage)

# Species classification per stage. Each class lists integer species codes
# that are later matched against the values of `dominant_area_species`.
# The order of the classes also determines the legend order of the final plot.
stages = {
    'stage_1': {
        'class_1': [12, 743, 761, 746, 129, 375, 241],
        'class_2': [531, 97, 371, 316, 318, 541, 372, 762, 833],
        'class_3': [832, 402, 802, 261]
    },
    'stage_2': {
        'class_1': [94, 241, 12, 375, 97, 972, 743, 531, 746, 371],
        'class_2': [316, 261, 126, 129, 318, 541, 762],
        'class_3': [833, 621, 832]
    },
    'stage_3': {
        'class_1': [241],
        'class_2': [12, 126, 97, 375, 746],
        'class_3': [371, 531, 316, 743, 261, 832, 318, 129, 802],
        'class_4': [837, 541, 833, 762, 621]
    },
    'stage_4': {
        'class_1': [95],
        'class_2': [71, 241],
        'class_3': [126, 97, 375],
        'class_4': [371, 832, 531, 802, 318, 261, 837, 129, 316],
        'class_5': [951, 541, 833, 762],
        'class_6': [621]
    }
}

# Get the species classifications for the selected stage. A misspelled
# `stage` would otherwise surface as a bare KeyError; name the valid options.
if stage not in stages:
    raise ValueError(
        f'Unknown stage {stage!r}; expected one of {sorted(stages)}'
    )
stage_classes = stages[stage]

# The input CSV is read from the parent of the current working directory.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
file_path = os.path.join(parent_directory, 'NE_Plot_Composition_Climate.csv')
data = pd.read_csv(file_path)

# Keep records with INVYR > 1999, CONDID == 1, CONDPROP_UNADJ >= 0.9,
# eAG >= 1, and stand age in the half-open window (stage_min, stage_max].
filtered_data = data[(data['INVYR'] > 1999) &
                     (data['CONDID'] == 1) &
                     (data['CONDPROP_UNADJ'] >= 0.9) &
                     (data['eAG'] >= 1) &
                     (data['STDAGE'] > stage_min) &
                     (data['STDAGE'] <= stage_max)].copy()

# Fail early with a readable message: the StandardScaler/PCA steps that follow
# would otherwise reject an empty selection with a less obvious error about
# having 0 samples.
if filtered_data.empty:
    raise ValueError(
        f'No plots left after filtering for STDAGE in ({stage_min}, {stage_max}]; '
        f'check the stage bounds and the input file {file_path}'
    )

# Separate the climate variables: every column from 'AHM' through the last
# column is treated as a climate predictor (this relies on the CSV's column order).
clm_vars = filtered_data.loc[:, 'AHM':]

# PCA analysis
# Step 1: Standardize the climate variables (zero mean, unit variance) so
# that variables measured in different units contribute comparably.
scaler = StandardScaler()
clm_vars_scaled = scaler.fit_transform(clm_vars)

# Step 2: Perform PCA on the standardized climate variables, keeping all components.
pca = PCA()
pca_scores = pca.fit_transform(clm_vars_scaled)

# Step 3: Wrap the scores in a DataFrame aligned on the filtered rows' index
# so the concat below attaches each score to the right plot record.
pca_columns = [f'PC{i + 1}' for i in range(pca_scores.shape[1])]
pca_df = pd.DataFrame(pca_scores, columns=pca_columns, index=filtered_data.index)

# Combine the PCA scores with the original data
filtered_data = pd.concat([filtered_data, pca_df], axis=1)

# Save the filtered plot records together with their PC score columns.
# The PCA components (loadings) are not written to this file.
# Create the results directory first; to_csv does not create missing folders.
output_dir = os.path.join(parent_directory, results_file)
os.makedirs(output_dir, exist_ok=True)
pca_output_file = os.path.join(output_dir, savefile)
filtered_data.to_csv(pca_output_file, index=False)

# Number of leading components to report. PCA() keeps
# min(n_samples, n_features) components, so there may be fewer than five
# (few climate columns or very few plots). Indexing PC1-PC5 unconditionally
# would then raise IndexError / a column-count mismatch.
n_report = min(5, pca.n_components_)

# Print the explained variance ratio of the leading PCs
print(f'Explained variance by the first {n_report} principal components:')
for i, ratio in enumerate(pca.explained_variance_ratio_[:n_report]):
    print(f'PC{i+1}: {ratio:.4f}')

# Plot cumulative explained variance
cumulative_variance = pca.explained_variance_ratio_.cumsum()
plt.figure(figsize=(8, 6))
plt.plot(range(1, len(cumulative_variance) + 1),
         cumulative_variance,
         marker='o', linestyle='--', color='b')
plt.title('Cumulative Explained Variance by Principal Components')
plt.xlabel('Number of Principal Components')
plt.ylabel('Cumulative Explained Variance')
plt.grid(True)
plt.show()

# Loadings of each climate variable on the leading PCs
# (rows: climate variables, columns: PC1..PC{n_report}).
loadings = pd.DataFrame(pca.components_.T[:, :n_report],
                        columns=[f'PC{i+1}' for i in range(n_report)],
                        index=clm_vars.columns)

# Rank variables by absolute loading; the printed values are |loading|,
# so the sign of each contribution is not shown.
print(f'\nTop 5 variables contributing to each of the first {n_report} principal components:')
for pc_name in loadings.columns:
    print(f'\nTop 5 variables for {pc_name}:')
    top_5_vars = loadings[pc_name].abs().sort_values(ascending=False).head(5)
    print(top_5_vars)

# Rank dominant species by the number of plots they dominate and keep the
# `top_species` most frequent ones.
# NOTE: `top_species` is rebound here from the configured count (an int) to
# the Index of the selected species codes.
species_counts = filtered_data['dominant_area_species'].value_counts()
top_species = species_counts.index[:top_species]

# Restrict to plots dominated by one of the selected species, keeping only
# the columns needed for the PC1/PC2 summary.
is_top_plot = filtered_data['dominant_area_species'].isin(top_species)
top_species_data = filtered_data.loc[is_top_plot]
pc1_pc2_data = top_species_data.loc[:, ['dominant_area_species', 'PC1', 'PC2']]

# For each species: centroid in PC1/PC2 space and the number of plots.
species_stats = pc1_pc2_data.groupby('dominant_area_species').agg(
    mean_PC1=pd.NamedAgg(column='PC1', aggfunc='mean'),
    mean_PC2=pd.NamedAgg(column='PC2', aggfunc='mean'),
    count=pd.NamedAgg(column='PC1', aggfunc='size'),
)

# Colour for each class name; species in classes not listed here fall back to gray.
class_colors = {
    'class_1': 'lightblue',
    'class_2': 'blue',
    'class_3': 'yellow',
    'class_4': 'darkgreen',
    'class_5': 'purple',
    'class_6': 'silver',
}

# Map every species code of the selected stage to its class colour.
# An exact class-name lookup replaces the previous substring test
# ('1' in class_name), which would also have matched e.g. 'class_10' or 'class_12'.
# If a species appears in several classes, the last class wins (same as before).
color_map = {
    species: class_colors.get(class_name, 'gray')
    for class_name, species_list in stage_classes.items()
    for species in species_list
}

# Plot: mean PC1 vs mean PC2 for each top species. Marker area scales with
# the number of plots (count * 10), and colour encodes the species' class
# (gray if the species is not in any class of the selected stage).
point_colors = [color_map.get(species, 'gray') for species in species_stats.index]

plt.figure(figsize=(10, 8))
plt.scatter(species_stats['mean_PC1'],
            species_stats['mean_PC2'],
            s=species_stats['count'] * 10,
            c=point_colors,
            alpha=0.6)

# Annotate each point with its species code.
for species, row in species_stats.iterrows():
    plt.text(row['mean_PC1'], row['mean_PC2'], species, fontsize=9, ha='right')

plt.title(f'PC1 vs PC2 for Dominant Species in {stage}')
plt.xlabel('PC1')
plt.ylabel('PC2')
plt.grid(True)

# Legend with one entry per class, coloured like that class's species.
# Empty classes are skipped, and uncoloured classes show gray instead of raising KeyError.
handles = [plt.Line2D([0], [0], marker='o', color='w', label=class_name,
                      markersize=10,
                      markerfacecolor=color_map.get(species_list[0], 'gray'))
           for class_name, species_list in stage_classes.items()
           if species_list]
plt.legend(handles=handles, title="Classes")

plt.show()
